import os
from typing import Dict, Iterable, List
import numpy as np
import torch.optim as optim

from dataset import *

from victim_model.text_classifier import TextClassifier
from config import Config
from tools.logger import Logger
from tools.saver import Saver
from tools.color import Color
from tools.optimizer import AdamY
from tools.device_manager import DeviceManager

from allennlp.data import Vocabulary, DataLoader, Instance
from allennlp.training.trainer import GradientDescentTrainer
from allennlp.training import Checkpointer
from allennlp.training.learning_rate_schedulers import SlantedTriangular
from allennlp.models import Model

# Module-level setup shared by every function below: load the run
# configuration, then build the Saver (used at the end of main() to save the
# trained model) and the Logger from it, and echo the config to stdout.
cf = Config()
saver = Saver(cf.model_id, d_ckpt=cf.d_ckpt)
# Logger is configured with the 'train' entry of cf.p_log and the quiet flag.
logger = Logger(cf.p_log['train'], quiet=cf.quiet)
print(cf)


def build_dataset_reader():
    """Return the dataset reader matching ``cf.dataset``.

    Supported names are ``'imdb'``, ``'agnews'`` and ``'mr'``; the chosen
    reader is constructed with the global config ``cf``.

    Raises:
        ValueError: if ``cf.dataset`` is not a supported name. The message
            names the rejected value and lists the accepted choices.
    """
    if cf.dataset == 'imdb':
        reader = IMDBDatasetReader(cf)
    elif cf.dataset == 'agnews':
        reader = AGNewsDatasetReader(cf)
    elif cf.dataset == 'mr':
        reader = MRDatasetReader(cf)
    else:
        # Keep the original 'dataset error' prefix, but say what was wrong so
        # a typo in the config is diagnosable from the traceback alone.
        raise ValueError(f"dataset error: unsupported dataset {cf.dataset!r} "
                         f"(expected one of 'imdb', 'agnews', 'mr')")
    return reader


def build_vocab(instances: Iterable[Instance]) -> Vocabulary:
    """Build a vocabulary from ``instances``, capped at ``cf.max_vocab_size``.

    ``Vocabulary.from_instances`` is a classmethod, so it is called on the
    class directly rather than on a throwaway ``Vocabulary()`` instance.
    """
    return Vocabulary.from_instances(instances, max_vocab_size=cf.max_vocab_size)


def build_model(vocab: Vocabulary) -> Model:
    """Create a ``TextClassifier`` over ``vocab`` (configured by ``cf``) on the GPU."""
    # nn.Module.cuda() moves parameters in place and returns the module itself.
    return TextClassifier(cf, vocab).cuda()


def build_data_loader(train_dataset, test_dataset):
    """Wrap the train and test datasets in DataLoaders of size ``cf.batch_size``.

    Only the training loader shuffles; the test loader preserves dataset order.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    def make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=cf.batch_size, shuffle=shuffle)

    return make_loader(train_dataset, True), make_loader(test_dataset, False)


def build_trainer(train_loader, test_loader, model):
    """Create a ``GradientDescentTrainer`` for ``model``.

    Optimizer choice depends on ``cf.encoder``:
      * encoders containing ``'bert'`` use ``AdamY`` with lr=2e-5;
      * all others use ``torch.optim.Adam`` with lr=1e-3, weight_decay=1e-5.
    Either optimizer is driven by a slanted-triangular LR schedule over
    ``cf.epoch`` epochs. ``test_loader`` is used as the validation loader,
    and the best epoch is the one with the highest ``accuracy``.
    """
    if 'bert' in cf.encoder:
        optimizer = AdamY(model, lr=2e-5)
    else:
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

    # SlantedTriangular places its warm-up/decay boundary using the number of
    # batches per epoch. Without it, AllenNLP has to estimate the epoch length
    # from the batches seen so far, which distorts the schedule in epoch 1.
    try:
        steps_per_epoch = len(train_loader)
    except TypeError:
        # Loaders over lazy / iterable datasets have no length; fall back to
        # letting the scheduler estimate it, as before.
        steps_per_epoch = None
    lr_scheduler = SlantedTriangular(
        optimizer,
        num_epochs=cf.epoch,
        num_steps_per_epoch=steps_per_epoch,
        cut_frac=0.1,
        ratio=100,
    )

    # No checkpointer or patience is passed, so the trainer uses its defaults
    # for both.
    trainer = GradientDescentTrainer(
        model=model,
        optimizer=optimizer,
        learning_rate_scheduler=lr_scheduler,
        data_loader=train_loader,
        validation_data_loader=test_loader,
        validation_metric='+accuracy',
        num_epochs=cf.epoch,
        cuda_device=0,
    )
    return trainer


def main():
    """Run the full pipeline: read data, build vocab, train, log and save.

    The test split is passed to the trainer as its validation data, so the
    "test" accuracies in the logged summary are the trainer's validation
    metrics. The model is saved via ``saver`` after training returns.
    """
    logger.print(Color.green('reading data...'))
    reader = build_dataset_reader()
    # Dict insertion order keeps the read order: 'train' first, then 'test'.
    splits = {name: reader.read(cf.p_split[name]) for name in ('train', 'test')}

    logger.print(Color.green('building vocabulary...'))
    # The vocabulary is built from the training split only.
    vocab = build_vocab(splits['train'])
    for split_data in splits.values():
        split_data.index_with(vocab)

    train_loader, test_loader = build_data_loader(splits['train'], splits['test'])

    logger.print(Color.green('training...'))
    model = build_model(vocab)
    trainer = build_trainer(train_loader, test_loader, model)
    metrics = trainer.train()

    summary = (f'[{cf.model_id} {cf.d_ckpt}] '
               f'best_epoch: {metrics["best_epoch"]} '
               f'best_test_accuracy: {metrics["best_validation_accuracy"]} '
               f'last_test_accuracy: {metrics["validation_accuracy"]}')
    logger.print(Color.green(summary))
    logger.log(summary)
    saver.save(model)


if __name__ == '__main__':
    # Run training inside the DeviceManager context built from cf.device.
    device_context = DeviceManager(cf.device)
    with device_context:
        main()
